import numpy as np
from collections import defaultdict
import torch

class LinearValueProcessor:
    """Linearly anneal a value from ``start_eps`` to ``end_eps``.

    Calling the instance with a frame counter returns the interpolated value.
    From ``end_eps_frames`` onward the value stays at ``end_eps``.
    """

    def __init__(self, start_eps, end_eps, end_eps_frames):
        self.start_eps = start_eps
        self.end_eps = end_eps
        self.end_eps_frames = end_eps_frames

    def __call__(self, frame):
        """Return the annealed value for ``frame``."""
        if frame >= self.end_eps_frames:
            return self.end_eps
        progress = frame / self.end_eps_frames
        remaining = 1.0 - progress
        return progress * self.end_eps + remaining * self.start_eps

class DefaultRewardsShaper:
    """Shape rewards and add an exploration bonus over discrete state bins.

    The extrinsic reward is first transformed: affine, then clip, then optional log.
    Column 0 of the result is then overwritten with a sparse success reward
    (``0.05 * success``) plus an intrinsic bonus:

    * default: a count-based bonus ``1 / sqrt(count + 1)``, rescaled to a
      batch mean of 0.001 and clipped to ``[0, 0.1]``;
    * ``use_noveld=True``: a NovelD-style bonus that only fires on the first
      visit of a bin within the current episode.

    NOTE(review): ``__call__`` and the binning helpers use torch ops
    unconditionally, so only ``is_torch=True`` is fully supported.
    """

    def __init__(self, scale_value = 1, shift_value = 0, min_val=-np.inf, max_val=np.inf, log_val=False, is_torch=True, num_bins=100000, binning=False,
                 device='cuda', num_envs=2048, verbose=True, use_noveld=False):
        """
        Args:
            scale_value, shift_value: reward is transformed as ``(r + shift) * scale``.
            min_val, max_val: clip bounds applied after the affine transform.
            log_val: apply ``log`` after clipping.
            is_torch: use torch (True) or numpy (False) ops and buffers.
            num_bins: number of discrete bins tracked by the visit counters.
            binning: stored on the instance; ``__call__`` does not consult it.
            device: torch device for the count buffers (was hard-coded 'cuda').
            num_envs: number of parallel environments, sizing the per-env
                NovelD buffers (was hard-coded 2048).
            verbose: print per-call debug diagnostics. This was the previous,
                unconditional behaviour.
            use_noveld: use the NovelD-style episodic bonus instead of the
                count bonus. False matches the previous always-on count path.
        """
        self.scale_value = scale_value
        self.shift_value = shift_value
        self.min_val = min_val
        self.max_val = max_val
        self.log_val = log_val
        self.is_torch = is_torch
        self.binning = binning
        self.num_envs = num_envs
        self.verbose = verbose
        self.use_noveld = use_noveld
        # Reference values per progress variable ('min<i>' / 'max<i>'), fixed
        # from the first batch seen by compute_bin_from_progress.
        self.extras = {}

        if self.is_torch:
            self.log = torch.log
            self.clip = torch.clamp
            # Global visit count per bin.
            self.bin_counts = torch.zeros(num_bins, device=device)
            # Per-env, per-episode visit counts. This buffer is only needed by
            # the NovelD path and is large (num_envs * num_bins), so it is
            # allocated only when that path is enabled.
            self.episode_bin_counts = (torch.zeros((num_envs, num_bins), device=device)
                                       if use_noveld else None)
            # Novelty of each env's previous observation (NovelD path).
            self.last_novelty = torch.zeros(num_envs, device=device)
        else:
            self.log = np.log
            self.clip = np.clip
            self.bin_counts = np.zeros(num_bins)
            self.episode_bin_counts = np.zeros((num_envs, num_bins)) if use_noveld else None
            self.last_novelty = np.zeros(num_envs)

        # Seed for the fixed random projection used by simhash_bins.
        self.simhash_seed = torch.randint(0, 1000, (1,), device='cpu').item()

    def compute_bin_from_progress(self, progress_vars, progress_directions):
        """Map per-env progress variables to a single integer bin per env.

        Args:
            progress_vars: sequence of tensors, one per progress variable,
                each presumably of shape (num_envs,). It is NOT modified.
            progress_directions: truthy if larger values of the matching
                variable mean more progress, falsy if smaller values do.

        Returns:
            long tensor of bin indices with the shape of ``progress_vars[0]``.
        """
        # Work on a copy so the caller's list/tensor is not overwritten.
        normalized = list(progress_vars)
        for i, increasing in enumerate(progress_directions):
            var = normalized[i]
            if increasing:
                key = 'min' + str(i)
                if key not in self.extras:
                    self.extras[key] = torch.min(var)
                normalized[i] = torch.clamp(var - self.extras[key], min=0)
            else:
                key = 'max' + str(i)
                if key not in self.extras:
                    self.extras[key] = torch.max(var)
                ref = self.extras[key]
                if ref <= 0:
                    # Cannot normalize by a non-positive max (ref == 0 used to
                    # divide by zero). Just flip the sign so that progress
                    # still increases.
                    normalized[i] = -1 * var
                else:
                    # In [0, 1] for values between 0 and the cached max.
                    normalized[i] = torch.clamp(ref - var, min=0) / ref
        if self.verbose:
            print("Extras", self.extras)
        # Every variable now increases with progress, so the bin is a weighted
        # sum. The last variable gets a much finer granularity (x1020) than
        # the others (x20).
        last = len(normalized) - 1
        binning = torch.zeros(normalized[0].shape, dtype=torch.long, device=normalized[0].device)
        for i, var in enumerate(normalized):
            scale = 1000 * (i == last) + 20
            binning += (var * scale).long() % 10000
        return binning

    def simhash(self, x, proj):
        """SimHash ``x`` with the rows of ``proj`` as random hyperplanes.

        Each hyperplane contributes one bit (1 if the projection is positive).
        The bits are packed into an integer, so codes lie in
        ``[0, 2**proj.shape[0])``.
        """
        ys = torch.matmul(x, proj.T)
        bit_positions = torch.arange(ys.shape[-1], device=x.device)
        mask = (ys > 0).long()
        bits = mask * (2 ** bit_positions).long()
        return bits.sum(dim=-1)

    def simhash_bins(self, obs, hashdim = 14):
        """Hash ``obs`` (features on the last axis) to ``hashdim``-bit bins.

        The projection matrix is regenerated from ``self.simhash_seed`` on
        every call, so the hash is stable for the lifetime of the instance.
        """
        g = torch.Generator(device=obs.device)
        g.manual_seed(self.simhash_seed)
        proj_matx = torch.randn((hashdim, obs.shape[-1]), generator=g, device=obs.device)
        return self.simhash(obs, proj_matx)

    def _resolve_bins(self, bins):
        """Return bin indices: ``bins`` itself if it is a tensor, otherwise
        the ``(progress_vars, progress_directions)`` pair is converted."""
        if isinstance(bins, torch.Tensor):
            return bins
        return self.compute_bin_from_progress(bins[0], bins[1])

    def _count_bonus(self, bins):
        """Count-based bonus; also records the visits in ``bin_counts``."""
        obs_bins = self._resolve_bins(bins)
        if self.verbose:
            print("bins = ", obs_bins, obs_bins.shape, torch.min(obs_bins), torch.max(obs_bins), torch.mean(obs_bins.float()))
        reward_boost = 1 / torch.sqrt(self.bin_counts[obs_bins] + 1)
        # Rescale so the batch-mean bonus is 0.001.
        reward_boost = reward_boost / reward_boost.mean() * 0.001
        # Cap any single bonus at 0.1.
        reward_boost = torch.clamp(reward_boost, 0, 0.1)
        # Counts are updated after scoring, so a batch is judged against prior visits only.
        self.bin_counts += torch.bincount(obs_bins, minlength=self.bin_counts.size(0))
        return reward_boost

    def _noveld_bonus(self, bins, dones):
        """NovelD-style bonus: ``max(novelty - alpha * last_novelty, 0)``,
        paid only on the first visit of a bin within the current episode."""
        if self.verbose:
            print("NovelD")
        if self.episode_bin_counts is None:
            self.episode_bin_counts = torch.zeros(
                (self.num_envs, self.bin_counts.shape[0]), device=self.bin_counts.device)
        # Reset episodic counts for envs whose episode ended.
        if dones is not None:
            self.episode_bin_counts[dones > 0] = 0
        obs_bins = self._resolve_bins(bins)
        env_idx = torch.arange(obs_bins.shape[0], device=obs_bins.device)
        self.episode_bin_counts[env_idx, obs_bins] += 1
        novelty = 1 / torch.sqrt(self.bin_counts[obs_bins] + 1)
        alpha = 0.5
        first_visit = (self.episode_bin_counts[env_idx, obs_bins] == 1).float()
        bonus = torch.clamp(first_visit * (novelty - alpha * self.last_novelty), min=0)
        self.last_novelty = novelty
        # Without this update the global counts never grow and novelty stays at 1.
        self.bin_counts += torch.bincount(obs_bins, minlength=self.bin_counts.size(0))
        if self.verbose:
            print("bins = ", obs_bins, obs_bins.shape, torch.min(obs_bins), torch.max(obs_bins), torch.mean(obs_bins.float()))
        return bonus * 0.1

    def __call__(self, reward, success=None, bins=None, relevant_features=None, dones=None):
        """Shape ``reward`` and write sparse success + exploration bonus into column 0.

        Args:
            reward: reward tensor; assumed 2-D (num_envs, k), since column 0
                is overwritten — TODO confirm against callers.
            success: per-env success tensor (any dtype; it is cast to float).
                Required.
            bins: either a long tensor of bin indices, or a
                ``(progress_vars, progress_directions)`` pair.
            relevant_features: unused.
            dones: per-env done flags; used only by the NovelD path.

        Raises:
            TypeError: if ``success`` is None.
        """
        reward = reward + self.shift_value
        reward = reward * self.scale_value
        reward = self.clip(reward, self.min_val, self.max_val)
        if self.log_val:
            reward = self.log(reward)

        if success is None:
            raise TypeError("DefaultRewardsShaper requires a `success` tensor")
        if self.verbose:
            # torch.mean rejects bool/int tensors, so the stats use a float view.
            success_f = success.float()
            print("success = ", success, torch.mean(success_f), torch.max(success_f))

        if self.use_noveld:
            reward_boost = self._noveld_bonus(bins, dones)
        else:
            reward_boost = self._count_bonus(bins)
        if self.verbose:
            print("Shapes", reward.shape, success.shape, reward_boost.shape)
        # The extrinsic column becomes sparse success + intrinsic bonus.
        reward[:, 0] = success.float() * 0.05 + reward_boost
        return reward


def dicts_to_dict_with_arrays(dicts, add_batch_dim = True):
    """Merge a sequence of dicts into one dict of numpy arrays, keyed as the inputs.

    Args:
        dicts: sequence of dicts; values for the same key are gathered in order.
        add_batch_dim: if True, array values are stacked along a new leading
            axis; otherwise they are concatenated along their first axis.
            Scalar values always become a 1-D array.

    Returns:
        dict mapping each key to an np.ndarray. NOTE: when ``dicts`` has zero
        or one element it is returned unchanged (not a dict). Kept for
        compatibility with existing callers.
    """
    def _all_scalars(values):
        # Per-element check instead of np.shape(values): identical for
        # homogeneous input, but does not raise on ragged arrays that
        # np.concatenate can still join.
        return all(np.ndim(v) == 0 for v in values)

    def stack(v):
        return np.array(v) if _all_scalars(v) else np.stack(v)

    def concatenate(v):
        return np.array(v) if _all_scalars(v) else np.concatenate(v)

    if len(dicts) <= 1:
        return dicts

    grouped = defaultdict(list)
    for sub in dicts:
        for key in sub:
            grouped[key].append(sub[key])

    concat_func = stack if add_batch_dim else concatenate
    return {k: concat_func(v) for k, v in grouped.items()}

def unsqueeze_obs(obs):
    """Add a leading batch axis to a tensor, or to every tensor in a nested dict.

    A 1-D tensor holding a single element is returned as-is. Dicts are
    updated in place and also returned.
    """
    if type(obs) is not dict:
        dims = obs.size()
        needs_batch_dim = len(dims) > 1 or dims[0] > 1
        return obs.unsqueeze(0) if needs_batch_dim else obs
    for key in obs:
        obs[key] = unsqueeze_obs(obs[key])
    return obs

def flatten_first_two_dims(arr):
    """Merge the two leading axes of ``arr`` into one.

    Arrays/tensors with at most two dimensions are flattened to 1-D, because
    ``shape[2:]`` is empty for them.
    """
    trailing_shape = arr.shape[2:]
    return arr.reshape(-1, *trailing_shape)

def free_mem():
    """Ask glibc to return freed heap memory to the OS via ``malloc_trim(0)``.

    Best-effort: on platforms without glibc (macOS, Windows, musl-based
    Linux), ``libc.so.6`` cannot be loaded or lacks ``malloc_trim``. In that
    case the call is skipped instead of raising, since trimming is only an
    optimization.

    Returns:
        True if ``malloc_trim`` was called, False if it is unavailable.
    """
    import ctypes
    try:
        trim = ctypes.CDLL('libc.so.6').malloc_trim
    except (OSError, AttributeError):
        return False
    trim(0)
    return True